from Causal.ac_infer.Environment.environment import non_state_factors, strip_instance
import numpy as np

class FlatNormalization:
    """Affine normalization of flat state vectors and goals to roughly [-1, 1].

    A value ``x`` whose limits are ``[lo, hi]`` is normalized as
    ``(x - mean) / var`` with ``mean = (hi + lo) / 2`` and
    ``var = (hi - lo) / 2``, so ``lo`` maps to -1 and ``hi`` maps to 1.
    ``var`` is therefore the half-range, not a statistical variance; the
    attribute names are kept for compatibility with existing callers.

    All four (de)normalize methods accept either an input with the same rank
    as the statistics, or an input of a different rank (e.g. a batch), in
    which case the statistics get a leading axis before broadcasting.

    NOTE(review): a feature with ``lo == hi`` has ``var == 0``, so
    normalizing it divides by zero (numpy yields inf/nan). Confirm the limit
    tables never contain degenerate ranges.
    """

    def __init__(self, all_names, lim_dict, goal_lims):
        """Build flat state limits and goal statistics.

        Args:
            all_names: factor names. Names in ``non_state_factors`` are
                skipped; the rest are looked up in ``lim_dict`` (after
                ``strip_instance``) and their limits are concatenated in
                order.
            lim_dict: maps a stripped factor name to a pair whose ``[0]`` is
                the lower and ``[1]`` the upper limit array for that factor.
            goal_lims: ``(lower, upper)`` goal limits; arrays, sequences or
                scalars (coerced with ``np.asarray``).
        """
        state_names = [n for n in all_names if n not in non_state_factors]
        if state_names:
            # Look each factor up once and split into lower/upper parts.
            factor_lims = [lim_dict[strip_instance(n)] for n in state_names]
            self.flat_lims = (np.concatenate([lims[0] for lims in factor_lims]),
                              np.concatenate([lims[1] for lims in factor_lims]))
        else:
            # np.concatenate raises on an empty list; with no state factors
            # the limits are simply zero-length vectors.
            self.flat_lims = (np.zeros(0), np.zeros(0))
        self.goal_lims = goal_lims
        # Coerce so list/tuple/scalar limits do arithmetic elementwise
        # (``list + list`` would concatenate) and expose ``.shape``.
        goal_low = np.asarray(goal_lims[0])
        goal_high = np.asarray(goal_lims[1])
        self.goal_mean = (goal_high + goal_low) / 2
        self.goal_var = (goal_high - goal_low) / 2
        self.mean = (self.flat_lims[1] + self.flat_lims[0]) / 2
        self.var = (self.flat_lims[1] - self.flat_lims[0]) / 2

    @staticmethod
    def _match_rank(obs, stat):
        """Return ``stat`` as-is if ``obs`` has the same rank, else with a leading axis."""
        if np.ndim(obs) == np.ndim(stat):
            return stat
        return np.expand_dims(stat, axis=0)

    @classmethod
    def _forward(cls, obs, mean, scale):
        """Compute ``(obs - mean) / scale`` with rank-matched statistics."""
        return (obs - cls._match_rank(obs, mean)) / cls._match_rank(obs, scale)

    @classmethod
    def _inverse(cls, obs, mean, scale):
        """Compute ``obs * scale + mean`` with rank-matched statistics."""
        return obs * cls._match_rank(obs, scale) + cls._match_rank(obs, mean)

    def normalize_obs(self, obs):
        """Map a flat state ``obs`` (or a batch of them) into normalized units."""
        return self._forward(obs, self.mean, self.var)

    def normalize_goal(self, obs):
        """Map a goal ``obs`` (or a batch of them) into normalized units."""
        return self._forward(obs, self.goal_mean, self.goal_var)

    def denormalize_obs(self, obs):
        """Invert ``normalize_obs``: map normalized states back to raw units."""
        return self._inverse(obs, self.mean, self.var)

    def denormalize_goal(self, obs):
        """Invert ``normalize_goal``: map normalized goals back to raw units."""
        return self._inverse(obs, self.goal_mean, self.goal_var)
